1762F - Good Pairs - CodeForces Solution

Tags: binary search, data structures, dp; difficulty rating *2600

Python Code:

import re
import functools
import random
import sys
import os
import math
from collections import Counter, defaultdict, deque
from functools import lru_cache, reduce
from itertools import accumulate, combinations, permutations
from heapq import nsmallest, nlargest, heappushpop, heapify, heappop, heappush
from io import BytesIO, IOBase
from copy import deepcopy
import threading
from typing import *
from operator import add, xor, mul, ior, iand, itemgetter
import bisect
BUFSIZE = 4096
inf = float('inf')

class FastIO(IOBase):
    """Byte-level buffered reader/writer working directly on a file descriptor.

    Input is pulled with ``os.read`` into an in-memory ``BytesIO`` and served
    from there. Output is accumulated in the same ``BytesIO`` and is only
    pushed to the descriptor by ``flush``.
    """

    # Number of lines known to be buffered but not yet returned by readline().
    newlines = 0

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # Treat the stream as writable unless its mode is read-only.
        self.writable = "x" in file.mode or "r" not in file.mode
        self.write = self.buffer.write if self.writable else None

    def read(self):
        """Read everything up to EOF and return the unread part of the buffer."""
        while True:
            # Request at least BUFSIZE bytes, or the file's size if larger.
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break  # EOF reached
            # Append at the end without moving the current read position.
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2)
            self.buffer.write(b)
            self.buffer.seek(ptr)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        """Return the next line (as bytes) from the descriptor."""
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # An empty read (EOF) counts as one line, so a final line without
            # a trailing newline can still be returned.
            self.newlines = b.count(b"\n") + (not b)
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2)
            self.buffer.write(b)
            self.buffer.seek(ptr)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        """Write all buffered output to the descriptor, then clear the buffer."""
        if self.writable:
            pending = memoryview(self.buffer.getvalue())
            # os.write may write fewer bytes than requested; loop until all of
            # the buffered output has been written.
            while pending:
                written = os.write(self._fd, pending)
                pending = pending[written:]
            self.buffer.truncate(0)
            self.buffer.seek(0)

class IOWrapper(IOBase):
    """ASCII text facade over a ``FastIO`` byte stream.

    Exposes ``read``/``readline``/``write``/``flush`` as instance attributes
    that encode and decode ASCII around the underlying ``FastIO`` buffer.
    """

    def __init__(self, file):
        self.buffer = FastIO(file)
        self.flush = self.buffer.flush
        self.writable = self.buffer.writable

        def _write(text):
            # Encode on the way out; the byte buffer only accepts bytes.
            return self.buffer.write(text.encode("ascii"))

        def _read():
            return self.buffer.read().decode("ascii")

        def _readline():
            return self.buffer.readline().decode("ascii")

        self.write = _write
        self.read = _read
        self.readline = _readline

# Replace the standard streams with buffered IOWrapper streams.
sys.stdin = IOWrapper(sys.stdin)
sys.stdout = IOWrapper(sys.stdout)


def input():
    """Read one line from the wrapped stdin, dropping trailing CR/LF characters."""
    return sys.stdin.readline().rstrip("\r\n")

def _tokens():
    """Read the next input line and split it on whitespace."""
    return input().split()


def _minus_one(token):
    """Parse *token* as an int and shift it to 0-based."""
    return int(token) - 1


def I():
    """Return the next input line as a string."""
    return input()


def II():
    """Return the next input line parsed as a single int."""
    line = input()
    return int(line)


def MII():
    """Return a lazy ``map`` of ints over the next line's tokens."""
    return map(int, _tokens())


def LI():
    """Return the next line's whitespace-separated tokens as a list of str."""
    return _tokens()


def LII():
    """Return the next line's tokens as a list of ints."""
    return [int(token) for token in _tokens()]


def GMI():
    """Return a lazy ``map`` of 0-based ints (each token minus one)."""
    return map(_minus_one, _tokens())


def LGMI():
    """Return the next line's tokens as a list of 0-based ints."""
    return [int(token) - 1 for token in _tokens()]

from types import GeneratorType
def bootstrap(f, stack=[]):
    """Run a generator-based recursive function without deep Python recursion.

    ``f`` must be a generator function. Instead of calling itself directly, it
    ``yield``s the recursive call (``r = yield f(x)``) and finally ``yield``s
    its result. The returned wrapper drives these generators with an explicit
    stack, so recursion depth is not bounded by Python's recursion limit.

    ``stack`` is deliberately a mutable default. It is evaluated once, so it is
    shared by every function decorated with ``bootstrap``. A decorated call made
    while a driver loop is already running (``stack`` non-empty) just returns
    its generator to that loop instead of starting a new one.
    """

    def wrappedfunc(*args, **kwargs):
        if stack:
            # Already inside a driver loop: hand the generator back to it.
            return f(*args, **kwargs)
        to = f(*args, **kwargs)
        while True:
            if type(to) is GeneratorType:
                # A nested call: suspend the caller and start the callee.
                stack.append(to)
                to = next(to)
            else:
                # A plain value: the top frame has finished. Pass the value to
                # its caller, or stop when the outermost frame is done.
                stack.pop()
                if not stack:
                    break
                to = stack[-1].send(to)
        return to

    return wrappedfunc

class SegmentTree:
    """Bottom-up (iterative) segment tree over positions ``0 .. n-1``.

    Leaves are stored at ``tree[n:2n]``. Internal node ``i`` holds
    ``merge(tree[2i], tree[2i + 1])``. ``query`` folds the right-hand partial
    results in reverse order, so ``merge`` must be associative and commutative
    (e.g. min, max, +), and ``default`` must be an identity for it.

    Args:
        n: number of positions.
        merge: binary combining function.
        default: initial value of every node and the result of an empty
            range. It defaults to ``5 * 10 ** 5``, the value this class
            previously hard-coded; with a min-merge over indices it acts as a
            "no index" sentinel.
    """

    def __init__(self, n, merge, default=5 * 10 ** 5):
        self.n = n
        self.default = default
        self.tree = [default] * (2 * self.n)
        self._merge = merge

    def query(self, ql, qr):
        """Return ``merge`` over positions ``ql..qr`` inclusive.

        An empty range (``ql > qr``) yields ``merge(default, default)``.
        """
        lans = rans = self.default
        ql += self.n
        qr += self.n + 1  # make the upper bound exclusive
        while ql < qr:
            if ql % 2:
                lans = self._merge(lans, self.tree[ql])
                ql += 1
            if qr % 2:
                qr -= 1
                rans = self._merge(rans, self.tree[qr])
            ql //= 2
            qr //= 2
        return self._merge(lans, rans)

    def update(self, index, value):
        """Set position ``index`` to ``value`` and recompute its ancestors."""
        index += self.n
        self.tree[index] = value
        # Stop at the root (node 1); node 0 is never read by query().
        while index > 1:
            index //= 2
            self.tree[index] = self._merge(self.tree[2 * index], self.tree[2 * index + 1])

class FenwickTree:
    """Zero-indexed Fenwick (binary indexed) tree for prefix sums over ``n`` slots."""

    def __init__(self, n):
        self.n = n
        self.bit = [0] * n

    def sum(self, r):
        """Return the total of slots ``0..r`` inclusive (0 when ``r < 0``)."""
        total = 0
        pos = r
        while pos >= 0:
            total += self.bit[pos]
            # Jump to the last slot before this node's covered range.
            pos = (pos & (pos + 1)) - 1
        return total

    def rsum(self, l, r):
        """Return the total of slots ``l..r`` inclusive."""
        upto_r = self.sum(r)
        before_l = self.sum(l - 1)
        return upto_r - before_l

    def add(self, idx, delta):
        """Add ``delta`` to slot ``idx``."""
        pos = idx
        size = self.n
        while pos < size:
            self.bit[pos] += delta
            pos |= pos + 1

def solve(nums, k):
    """Return sum(dp) for the right-to-left chain DP over ``nums``.

    Scanning ``i`` from right to left, the function does three things:

    * ``nxt`` is the nearest index ``j > i`` whose value lies in
      ``[nums[i] + 1, nums[i] + k]``.
    * ``dp[i] = dp[nxt] + (# of j > i with nums[j] in [nums[i] + 1, nums[nxt]])``.
    * When no such ``nxt`` exists, ``dp[i] = 0``.

    The function uses the module-level ``seg`` (a min segment tree of indices
    keyed by value) and ``fenwick`` (value counts). Both must be in their
    empty state on entry, and both are restored to it before returning.

    Assumes every value is in ``[0, 10 ** 5)``, as produced by the 0-based
    input reader. NOTE(review): confirm against the problem constraints.
    """
    limit = 10 ** 5      # size of the value domain of seg / fenwick
    empty = 5 * 10 ** 5  # seg's "no index" value (identity for its min-merge)
    n = len(nums)
    dp = [0] * n
    for i in range(n - 1, -1, -1):
        v = nums[i]
        # A chain step needs a strictly larger value. The largest value
        # (limit - 1) has none, and v + 1 would be outside seg's domain.
        if v < limit - 1:
            nxt = seg.query(v + 1, min(limit - 1, v + k))
            if nxt < n:
                dp[i] = dp[nxt] + fenwick.rsum(v + 1, nums[nxt])
        fenwick.add(v, 1)
        # Scanning right to left, this overwrites any larger index stored for v.
        seg.update(v, i)
    ans = sum(dp)
    # Undo every insertion so the shared structures are empty for the next call.
    for v in nums:
        fenwick.add(v, -1)
        seg.update(v, empty)
    return ans

t = II()
# Shared across test cases; solve() restores both to the empty state.
seg = SegmentTree(10 ** 5, lambda x, y: x if x <= y else y)
fenwick = FenwickTree(10 ** 5)
count = [0] * (10 ** 5)
for _ in range(t):
    n, k = MII()
    nums = LGMI()
    # Count pairs (i, j), i <= j, with equal values: each occurrence pairs
    # with itself and with every earlier occurrence of the same value.
    ans = 0
    for num in nums:
        count[num] += 1
        ans += count[num]
    if k:
        # Run the chain DP on the array and on its reverse.
        ans += solve(nums, k) + solve(nums[::-1], k)
    # Reset the occurrence counts for the next test case.
    for num in nums:
        count[num] -= 1
    print(ans)


